from sklearn import datasets
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

# Load the Iris dataset once and take its feature matrix and integer class labels.
# (The previous version called load_iris() twice, loading the data twice.)
iris = datasets.load_iris()
x_data = iris.data
y_data = iris.target

# Shuffle features and labels in lockstep: re-seeding NumPy with the same value
# before each in-place shuffle applies the same permutation to both arrays,
# provided they have the same length.
np.random.seed(116)
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)

TEST_SIZE = 30   # number of shuffled samples held out for testing
BATCH_SIZE = 32  # samples per batch in both datasets

# Split into training and test sets. Slicing at an explicit index (rather than
# with [:-TEST_SIZE]) keeps TEST_SIZE == 0 from producing an empty training set.
split = len(x_data) - TEST_SIZE
x_train = x_data[:split]  # dtype=float64
y_train = y_data[:split]
x_test = x_data[split:]
y_test = y_data[split:]

# Cast features to float32 so they can be multiplied with the float32
# trainable variables.
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)

# Pair each feature row with its label and group the pairs into batches.
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(BATCH_SIZE)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

# 定义神经网络种所有可训练参数
# All trainable parameters of the single-layer network (float32, the default dtype):
#   w1 -- [4, 3] weight matrix applied to the 4 input features, giving 3 class scores
#   b1 -- [3] bias added to each class score
# Both are drawn from a truncated normal with stddev 0.1 and op-level seed 1,
# created in the order w1 then b1.
w1, b1 = (
    tf.Variable(tf.random.truncated_normal(shape, stddev=0.1, seed=1))
    for shape in ([4, 3], [3])
)

# 嵌套循环迭代
# Training hyperparameters and per-epoch metric history.
epochs = 500  # number of passes over the training set
lr = 0.1      # learning rate; per the original note, too small a value may get stuck in a local minimum
train_loss_results = []  # mean training loss of each epoch
test_acc = []            # test-set accuracy after each epoch

for epoch in range(epochs):
    # --- Training pass: one gradient-descent update per batch. ---
    loss_all = 0    # sum of the per-batch mean losses in this epoch
    n_batches = 0   # stays 0 if train_db yields nothing
    for n_batches, (x_batch, y_batch) in enumerate(train_db, start=1):
        with tf.GradientTape() as tape:
            y = tf.nn.softmax(tf.matmul(x_batch, w1) + b1)  # predicted class probabilities
            y_ = tf.one_hot(y_batch, depth=3)                # one-hot encoded labels
            loss = tf.reduce_mean(tf.square(y_ - y))         # mean squared error of this batch
        grads = tape.gradient(loss, [w1, b1])
        w1.assign_sub(lr * grads[0])
        b1.assign_sub(lr * grads[1])
        # Bookkeeping happens outside the tape: it does not need to be recorded.
        loss_all += loss.numpy()

    # Average over the batches actually seen. A hard-coded batch count is
    # only correct for one particular dataset size and batch size.
    mean_loss = loss_all / max(n_batches, 1)
    print("Epoch {}, loss: {}".format(epoch, mean_loss))
    train_loss_results.append(mean_loss)

    # --- Evaluation pass: accuracy on the held-out test set. ---
    total_correct, total_number = 0, 0
    for x_batch, y_batch in test_db:
        y = tf.nn.softmax(tf.matmul(x_batch, w1) + b1)
        # The predicted class is the index of the largest probability.
        pred = tf.cast(tf.argmax(y, axis=1), dtype=y_batch.dtype)
        correct = tf.reduce_sum(tf.cast(tf.equal(pred, y_batch), dtype=tf.int32))
        total_correct += int(correct)
        total_number += x_batch.shape[0]

    acc = total_correct / total_number if total_number else 0.0
    print("test_acc:", acc)
    test_acc.append(acc)

def _plot_curve(title, ylabel, values, label):
    """Display one line chart of a per-epoch metric, with "Epoch" on the x axis."""
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.plot(values, label=label)
    plt.legend()
    plt.show()


# Training-loss curve first, then test-accuracy curve, each in its own window.
_plot_curve("Loss Function Curve", "Loss", train_loss_results, "$Loss$")
_plot_curve("Acc Curve", "Acc", test_acc, "$Accuracy$")